{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Affine X*W+B 层的实现\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "class Affine:\n",
    "    def __init__(self, W, b):\n",
    "        self.W = W\n",
    "        self.b = b\n",
    "        self.x = None\n",
    "        self.dW = None\n",
    "        self.db = None\n",
    "\n",
    "    def forward(self, x):\n",
    "        self.x = x\n",
    "        out = np.dot(x, self.W) + self.b\n",
    "        return out\n",
    "\n",
    "    def backward(self, dout):\n",
    "        dx = np.dot(dout, self.W.T)\n",
    "        self.dW = np.dot(self.x.T, dout)\n",
    "        self.db = np.sum(dout, axis=0)\n",
    "        return dx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [],
   "source": [
    "# Softmax-with-Loss层\n",
    "from common import functions\n",
    "\n",
    "class SoftmaxWithLoss:\n",
    "    def __init__(self):\n",
    "        self.loss = None    # 损失\n",
    "        self.y = None       # softmax输出\n",
    "        self.t = None       # 监督数据(one-hot vector)\n",
    "\n",
    "    def forward(self, x, t):\n",
    "        self.t = t\n",
    "        self.y = functions.softmax(x)\n",
    "        self.loss = functions.cross_entropy_error(self.y, self.t)\n",
    "        return self.loss\n",
    "\n",
    "    def backward(self, dout=1):\n",
    "        batch_size = self.t.shape[0]\n",
    "        dx = (self.y - self.t) / batch_size\n",
    "        return dx\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}